{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0e499e90-7a6d-4fab-8aab-31a4df417601",
   "metadata": {},
   "source": [
    "# SQL Database\n",
    "\n",
    "This notebook showcases an agent designed to interact with a SQL database.\n",
    "\n",
    "The agent builds on [SQLDatabaseChain](https://python.langchain.com/docs/use_cases/tabular/sqlite).\n",
    "\n",
    "It is designed to answer more general questions about a database, as well as to recover from errors.\n",
    "\n",
    "Note that, because this agent is in active development, not all of its answers may be correct.\n",
    "\n",
    "Additionally, there is no guarantee that the agent won't execute DML statements (such as `INSERT`, `UPDATE`, or `DELETE`) on your database in response to certain questions.\n",
    "\n",
    "Be careful running it on sensitive data!\n",
    "\n",
    "This notebook uses the example `Chinook` database.\n",
    "\n",
    "To set it up, follow [these instructions](https://database.guide/2-sample-databases-sqlite/) and place the `Chinook.db` file in the directory this notebook runs from, since the connection URI below uses the relative path `sqlite:///Chinook.db`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "39ee5811-7f7d-47a9-941a-83dcf4b74239",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARNING: There was an error checking the latest version of pip.\u001b[0m\u001b[33m\n",
      "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install --upgrade --quiet langchain langchain-community langchain-experimental langchain-openai"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "804533b1-2f16-497b-821b-c82d67fcf7b6",
   "metadata": {},
   "source": [
    "## Initialize our database"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "42bd5a41-672a-4a53-b70a-2f0c0555758c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.utilities.sql_database import SQLDatabase\n",
    "\n",
    "db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fabfc8b0-2354-4f4d-b334-cd8ae3a12b0a",
   "metadata": {},
   "source": [
    "## Quickstart\n",
    "\n",
    "Following the [SQL use case docs](https://python.langchain.com/docs/use_cases/sql/agents), we can use the `create_sql_agent` helper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c0504b5c-72a8-4407-9056-944615f1d480",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.agent_toolkits import create_sql_agent\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "# ChatOpenAI reads your API key from the OPENAI_API_KEY environment variable.\n",
    "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
    "agent_executor = create_sql_agent(llm, db=db, agent_type=\"openai-tools\", verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "75f2f603-6004-4363-a781-a6e352124ffa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3m\n",
      "Invoking: `sql_db_list_tables` with `{}`\n",
      "\n",
      "\n",
      "\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
      "Invoking: `sql_db_schema` with `Invoice,Customer`\n",
      "\n",
      "\n",
      "\u001b[0m\u001b[33;1m\u001b[1;3m\n",
      "CREATE TABLE \"Customer\" (\n",
      "\t\"CustomerId\" INTEGER NOT NULL, \n",
      "\t\"FirstName\" NVARCHAR(40) NOT NULL, \n",
      "\t\"LastName\" NVARCHAR(20) NOT NULL, \n",
      "\t\"Company\" NVARCHAR(80), \n",
      "\t\"Address\" NVARCHAR(70), \n",
      "\t\"City\" NVARCHAR(40), \n",
      "\t\"State\" NVARCHAR(40), \n",
      "\t\"Country\" NVARCHAR(40), \n",
      "\t\"PostalCode\" NVARCHAR(10), \n",
      "\t\"Phone\" NVARCHAR(24), \n",
      "\t\"Fax\" NVARCHAR(24), \n",
      "\t\"Email\" NVARCHAR(60) NOT NULL, \n",
      "\t\"SupportRepId\" INTEGER, \n",
      "\tPRIMARY KEY (\"CustomerId\"), \n",
      "\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n",
      ")\n",
      "\n",
      "/*\n",
      "3 rows from Customer table:\n",
      "CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n",
      "1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n",
      "2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n",
      "3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n",
      "*/\n",
      "\n",
      "\n",
      "CREATE TABLE \"Invoice\" (\n",
      "\t\"InvoiceId\" INTEGER NOT NULL, \n",
      "\t\"CustomerId\" INTEGER NOT NULL, \n",
      "\t\"InvoiceDate\" DATETIME NOT NULL, \n",
      "\t\"BillingAddress\" NVARCHAR(70), \n",
      "\t\"BillingCity\" NVARCHAR(40), \n",
      "\t\"BillingState\" NVARCHAR(40), \n",
      "\t\"BillingCountry\" NVARCHAR(40), \n",
      "\t\"BillingPostalCode\" NVARCHAR(10), \n",
      "\t\"Total\" NUMERIC(10, 2) NOT NULL, \n",
      "\tPRIMARY KEY (\"InvoiceId\"), \n",
      "\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n",
      ")\n",
      "\n",
      "/*\n",
      "3 rows from Invoice table:\n",
      "InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n",
      "1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n",
      "2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n",
      "3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n",
      "*/\u001b[0m\u001b[32;1m\u001b[1;3m\n",
      "Invoking: `sql_db_query` with `SELECT c.Country, SUM(i.Total) AS TotalSales FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC LIMIT 10;`\n",
      "responded: To list the total sales per country, I can query the \"Invoice\" and \"Customer\" tables. I will join these tables on the \"CustomerId\" column and group the results by the \"BillingCountry\" column. Then, I will calculate the sum of the \"Total\" column to get the total sales per country. Finally, I will order the results in descending order of the total sales.\n",
      "\n",
      "Here is the SQL query:\n",
      "\n",
      "```sql\n",
      "SELECT c.Country, SUM(i.Total) AS TotalSales\n",
      "FROM Invoice i\n",
      "JOIN Customer c ON i.CustomerId = c.CustomerId\n",
      "GROUP BY c.Country\n",
      "ORDER BY TotalSales DESC\n",
      "LIMIT 10;\n",
      "```\n",
      "\n",
      "Now, I will execute this query to get the total sales per country.\n",
      "\n",
      "\u001b[0m\u001b[36;1m\u001b[1;3m[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62)]\u001b[0m\u001b[32;1m\u001b[1;3mThe total sales per country are as follows:\n",
      "\n",
      "1. USA: $523.06\n",
      "2. Canada: $303.96\n",
      "3. France: $195.10\n",
      "4. Brazil: $190.10\n",
      "5. Germany: $156.48\n",
      "6. United Kingdom: $112.86\n",
      "7. Czech Republic: $90.24\n",
      "8. Portugal: $77.24\n",
      "9. India: $75.26\n",
      "10. Chile: $46.62\n",
      "\n",
      "To answer the second question, the country whose customers spent the most is the USA, with a total sales of $523.06.\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input': \"List the total sales per country. Which country's customers spent the most?\",\n",
       " 'output': 'The total sales per country are as follows:\\n\\n1. USA: $523.06\\n2. Canada: $303.96\\n3. France: $195.10\\n4. Brazil: $190.10\\n5. Germany: $156.48\\n6. United Kingdom: $112.86\\n7. Czech Republic: $90.24\\n8. Portugal: $77.24\\n9. India: $75.26\\n10. Chile: $46.62\\n\\nTo answer the second question, the country whose customers spent the most is the USA, with a total sales of $523.06.'}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_executor.invoke(\n",
    "    \"List the total sales per country. Which country's customers spent the most?\"\n",
    ")"
   ]
  },
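  {
   "cell_type": "markdown",
   "id": "3c9a6f1e-2b4d-4e8a-9f10-7d2e5b8c4a01",
   "metadata": {},
   "source": [
    "The raw query result above contains floating-point noise, for example `523.0600000000003`. The LLM then rounds these values when it writes its answer. The noise appears because SQLite has no fixed-point decimal type and stores these `NUMERIC(10, 2)` values as floating-point (`REAL`) numbers. If you would rather have the database return two-decimal totals directly, you can round in SQL. The sketch below runs the same aggregate the agent generated, plus a `ROUND` column, against a tiny in-memory database. The rows are made-up sample values, not Chinook data.\n",
    "\n",
    "```python\n",
    "import sqlite3\n",
    "\n",
    "# Made-up sample rows standing in for the Chinook Customer and Invoice tables.\n",
    "conn = sqlite3.connect(':memory:')\n",
    "conn.executescript('''\n",
    "CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT);\n",
    "CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER, Total NUMERIC(10, 2));\n",
    "INSERT INTO Customer VALUES (1, 'USA'), (2, 'Canada'), (3, 'USA');\n",
    "INSERT INTO Invoice VALUES (1, 1, 0.1), (2, 1, 0.2), (3, 2, 0.25), (4, 3, 0.05);\n",
    "''')\n",
    "\n",
    "# Same shape as the agent's query, with an extra rounded column for comparison.\n",
    "query = '''\n",
    "SELECT c.Country, SUM(i.Total) AS RawTotal, ROUND(SUM(i.Total), 2) AS TotalSales\n",
    "FROM Invoice i\n",
    "JOIN Customer c ON i.CustomerId = c.CustomerId\n",
    "GROUP BY c.Country\n",
    "ORDER BY TotalSales DESC\n",
    "'''\n",
    "rows = conn.execute(query).fetchall()\n",
    "for country, raw_total, rounded_total in rows:\n",
    "    print(country, raw_total, rounded_total)\n",
    "```\n",
    "\n",
    "Rounding in SQL only changes how the totals are presented. If exact decimal arithmetic matters, store amounts as integer cents instead."
   ]
  },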
  {
   "cell_type": "markdown",
   "id": "ec927ac6-9b2a-4e8a-9a6e-3e429191875c",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Toolkit\n",
    "\n",
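    "We can look at what the `create_sql_agent` helper above does under the hood by building the same kind of agent manually.\n",
    "\n",
    "`SQLDatabaseToolkit` bundles the tools the agent uses: listing tables (`sql_db_list_tables`), fetching table schemas with sample rows (`sql_db_schema`), checking a query with the LLM before running it (`sql_db_query_checker`), and executing queries (`sql_db_query`). The `llm` passed to the toolkit is used by the query checker."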
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "056ddd8a-9835-41ef-84a3-6b680c22248d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.agent_toolkits import SQLDatabaseToolkit\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "toolkit = SQLDatabaseToolkit(db=db, llm=ChatOpenAI(temperature=0))\n",
    "context = toolkit.get_context()\n",
    "tools = toolkit.get_tools()"
   ]
  },
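  {
   "cell_type": "markdown",
   "id": "8e2d4b7a-1f3c-4a9e-b6d5-0c1a2e3f4b52",
   "metadata": {},
   "source": [
    "The toolkit's tools appear by name in the agent traces in this notebook, for example `sql_db_list_tables` and `sql_db_schema`. For SQLite, the first returns the table names and the second returns each table's `CREATE TABLE` statement followed by three sample rows. To make those traces easier to read, here is a rough standard-library approximation of the two tools. `list_tables` and `describe_table` are hypothetical helpers for illustration only. They are not LangChain's implementation, which goes through SQLAlchemy and formats its output differently.\n",
    "\n",
    "```python\n",
    "import sqlite3\n",
    "\n",
    "\n",
    "def list_tables(conn):\n",
    "    # Approximates sql_db_list_tables: table names, sorted and comma-separated.\n",
    "    rows = conn.execute(\n",
    "        'SELECT name FROM sqlite_master WHERE type = ? ORDER BY name', ('table',)\n",
    "    ).fetchall()\n",
    "    return ', '.join(name for (name,) in rows)\n",
    "\n",
    "\n",
    "def describe_table(conn, table, sample_rows=3):\n",
    "    # Approximates sql_db_schema: the stored CREATE TABLE text plus a few rows.\n",
    "    row = conn.execute(\n",
    "        'SELECT sql FROM sqlite_master WHERE type = ? AND name = ?', ('table', table)\n",
    "    ).fetchone()\n",
    "    if row is None:\n",
    "        raise ValueError(f'no such table: {table}')\n",
    "    # The name was just checked against sqlite_master, so quoting it as an identifier is reasonable.\n",
    "    cursor = conn.execute(f'SELECT * FROM \"{table}\" LIMIT ?', (sample_rows,))\n",
    "    columns = [col[0] for col in cursor.description]\n",
    "    return row[0], columns, cursor.fetchall()\n",
    "\n",
    "\n",
    "# Toy table mirroring the shape of Chinook's Playlist table.\n",
    "conn = sqlite3.connect(':memory:')\n",
    "conn.execute('CREATE TABLE Playlist (PlaylistId INTEGER PRIMARY KEY, Name NVARCHAR(120))')\n",
    "conn.executemany(\n",
    "    'INSERT INTO Playlist (Name) VALUES (?)',\n",
    "    [('Music',), ('Movies',), ('TV Shows',), ('Audiobooks',)],\n",
    ")\n",
    "\n",
    "print(list_tables(conn))  # Playlist\n",
    "ddl, columns, sample = describe_table(conn, 'Playlist')\n",
    "print(ddl)\n",
    "print(columns)  # ['PlaylistId', 'Name']\n",
    "print(sample)\n",
    "```"
   ]
  },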
  {
   "cell_type": "markdown",
   "id": "08a653b1-02f2-45f3-b357-acff5a6b3869",
   "metadata": {},
   "source": [
    "### Use SQLDatabaseToolkit within an Agent\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3f673e63-49bf-4112-b74f-d7baa91ed81f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.agent_toolkits.sql.prompt import SQL_FUNCTIONS_SUFFIX\n",
    "from langchain_core.messages import AIMessage, SystemMessage\n",
    "from langchain_core.prompts.chat import (\n",
    "    ChatPromptTemplate,\n",
    "    HumanMessagePromptTemplate,\n",
    "    MessagesPlaceholder,\n",
    ")\n",
    "\n",
    "messages = [\n",
    "    HumanMessagePromptTemplate.from_template(\"{input}\"),\n",
    "    AIMessage(content=SQL_FUNCTIONS_SUFFIX),\n",
    "    MessagesPlaceholder(variable_name=\"agent_scratchpad\"),\n",
    "]\n",
    "\n",
    "prompt = ChatPromptTemplate.from_messages(messages)\n",
    "prompt = prompt.partial(**context)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "83f57336-1757-4a54-bf79-591c0a97e6a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.agents import create_openai_tools_agent\n",
    "from langchain.agents.agent import AgentExecutor\n",
    "\n",
    "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
    "\n",
    "agent = create_openai_tools_agent(llm, tools, prompt)\n",
    "\n",
    "agent_executor = AgentExecutor(\n",
    "    agent=agent,\n",
    "    tools=tools,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f77805fb-c05e-408c-9b69-6d2a774c208e",
   "metadata": {},
   "source": [
    "## Disclaimer ⚠️\n",
    "\n",
    "The agent may generate `INSERT`, `UPDATE`, or `DELETE` queries. If this is not acceptable, use a custom prompt or, more robustly, connect as a SQL user without write permissions.\n",
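    "\n",
    "For SQLite, one way to get a connection without write permissions is to open the file in read-only mode through an SQLite URI filename (`mode=ro`). The sketch below demonstrates this with Python's built-in `sqlite3` module on a throwaway temporary file instead of `Chinook.db`. Reads succeed, while writes fail with `sqlite3.OperationalError`. With SQLAlchemy, which `SQLDatabase.from_uri` uses, the equivalent URI should look like `sqlite:///file:Chinook.db?mode=ro&uri=true`. Treat that form as an assumption and confirm it against the SQLAlchemy SQLite dialect documentation for your version. For server databases, the analogous step is a dedicated database user that has only `SELECT` privileges.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import sqlite3\n",
    "import tempfile\n",
    "from pathlib import Path\n",
    "\n",
    "# Build a small throwaway database file as a stand-in for Chinook.db.\n",
    "path = os.path.join(tempfile.mkdtemp(), 'example.db')\n",
    "rw = sqlite3.connect(path)\n",
    "rw.execute('CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)')\n",
    "rw.execute('INSERT INTO Artist (Name) VALUES (?)', ('Example Artist',))\n",
    "rw.commit()\n",
    "rw.close()\n",
    "\n",
    "# mode=ro opens the file read-only; uri=True tells sqlite3 to parse the URI.\n",
    "ro = sqlite3.connect(Path(path).as_uri() + '?mode=ro', uri=True)\n",
    "print(ro.execute('SELECT Name FROM Artist').fetchall())  # [('Example Artist',)]\n",
    "\n",
    "# Any write attempt on the read-only connection raises OperationalError.\n",
    "try:\n",
    "    ro.execute('DELETE FROM Artist')\n",
    "    write_error = None\n",
    "except sqlite3.OperationalError as exc:\n",
    "    write_error = str(exc)\n",
    "print(write_error)\n",
    "```\n",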
    "\n",
    "An end user might also overload your SQL database by asking a simple question such as \"run the biggest query possible\". The generated query might look like:\n",
    "\n",
    "```sql\n",
    "SELECT * FROM \"public\".\"users\"\n",
    "    JOIN \"public\".\"user_permissions\" ON \"public\".\"users\".id = \"public\".\"user_permissions\".user_id\n",
    "    JOIN \"public\".\"projects\" ON \"public\".\"users\".id = \"public\".\"projects\".user_id\n",
    "    JOIN \"public\".\"events\" ON \"public\".\"projects\".id = \"public\".\"events\".project_id;\n",
    "```\n",
    "\n",
    "On a transactional SQL database, if any of the tables above contains millions of rows, this query could cause trouble for other applications using the same database.\n",
    "\n",
    "Most data-warehouse-oriented databases support user-level quotas for limiting resource usage.\n",
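    "\n",
    "SQLite itself has no user-level quotas. As a local safeguard, Python's `sqlite3` module lets you register a progress handler that aborts a statement once it exceeds a budget of SQLite virtual-machine instructions. The sketch below wraps query execution this way. `run_limited` is a hypothetical helper you would call around your own query execution; it is not a LangChain option. Its default budget is an arbitrary assumption that you should tune for your workload. On server databases, prefer built-in controls such as statement timeouts or quotas.\n",
    "\n",
    "```python\n",
    "import sqlite3\n",
    "\n",
    "\n",
    "def run_limited(conn, sql, max_steps=1_000_000, check_every=1_000):\n",
    "    # Hypothetical helper: abort the query after roughly max_steps VM instructions.\n",
    "    steps = 0\n",
    "\n",
    "    def budget_exceeded():\n",
    "        nonlocal steps\n",
    "        steps += check_every\n",
    "        # A truthy return value makes SQLite interrupt the running statement.\n",
    "        return steps > max_steps\n",
    "\n",
    "    conn.set_progress_handler(budget_exceeded, check_every)\n",
    "    try:\n",
    "        return conn.execute(sql).fetchall()\n",
    "    finally:\n",
    "        conn.set_progress_handler(None, 0)  # remove the handler again\n",
    "\n",
    "\n",
    "conn = sqlite3.connect(':memory:')\n",
    "print(run_limited(conn, 'SELECT 1 + 1'))  # [(2,)]\n",
    "\n",
    "# A deliberately expensive query: count to ten million with a recursive CTE.\n",
    "runaway = '''\n",
    "WITH RECURSIVE n(i) AS (SELECT 1 UNION ALL SELECT i + 1 FROM n WHERE i < 10000000)\n",
    "SELECT COUNT(*) FROM n\n",
    "'''\n",
    "try:\n",
    "    run_limited(conn, runaway)\n",
    "    outcome = 'completed'\n",
    "except sqlite3.OperationalError as exc:\n",
    "    outcome = f'aborted: {exc}'\n",
    "print(outcome)\n",
    "```\n",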
    "\n",
    "## Example: describing a table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f740fbef-d635-4121-856d-33593ae6bfc0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3m\n",
      "Invoking: `sql_db_list_tables` with ``\n",
      "\n",
      "\n",
      "\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
      "Invoking: `sql_db_schema` with `Playlist`\n",
      "\n",
      "\n",
      "\u001b[0m\u001b[33;1m\u001b[1;3m\n",
      "CREATE TABLE \"Playlist\" (\n",
      "\t\"PlaylistId\" INTEGER NOT NULL, \n",
      "\t\"Name\" NVARCHAR(120), \n",
      "\tPRIMARY KEY (\"PlaylistId\")\n",
      ")\n",
      "\n",
      "/*\n",
      "3 rows from Playlist table:\n",
      "PlaylistId\tName\n",
      "1\tMusic\n",
      "2\tMovies\n",
      "3\tTV Shows\n",
      "*/\u001b[0m\u001b[32;1m\u001b[1;3mThe schema of the \"Playlist\" table is as follows:\n",
      "\n",
      "- PlaylistId: INTEGER (Primary Key)\n",
      "- Name: NVARCHAR(120)\n",
      "\n",
      "Here are three sample rows from the \"Playlist\" table:\n",
      "\n",
      "| PlaylistId | Name      |\n",
      "|------------|-----------|\n",
      "| 1          | Music     |\n",
      "| 2          | Movies    |\n",
      "| 3          | TV Shows  |\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input': 'Describe the schema of the playlist table',\n",
       " 'output': 'The schema of the \"Playlist\" table is as follows:\\n\\n- PlaylistId: INTEGER (Primary Key)\\n- Name: NVARCHAR(120)\\n\\nHere are three sample rows from the \"Playlist\" table:\\n\\n| PlaylistId | Name      |\\n|------------|-----------|\\n| 1          | Music     |\\n| 2          | Movies    |\\n| 3          | TV Shows  |'}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_executor.invoke({\"input\": \"Describe the schema of the playlist table\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9abcfe8e-1868-42a4-8345-ad2d9b44c681",
   "metadata": {},
   "source": [
    "## Example: describing a table, recovering from an error\n",
    "\n",
    "In this example, the question asks about a `playlists` table, which doesn't exist. The agent lists the available tables and uses the closest match, `Playlist`, instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d78ddb6e-f499-4e90-9604-fe22a9c3592e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3m\n",
      "Invoking: `sql_db_list_tables` with ``\n",
      "\n",
      "\n",
      "\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
      "Invoking: `sql_db_schema` with `Playlist`\n",
      "\n",
      "\n",
      "\u001b[0m\u001b[33;1m\u001b[1;3m\n",
      "CREATE TABLE \"Playlist\" (\n",
      "\t\"PlaylistId\" INTEGER NOT NULL, \n",
      "\t\"Name\" NVARCHAR(120), \n",
      "\tPRIMARY KEY (\"PlaylistId\")\n",
      ")\n",
      "\n",
      "/*\n",
      "3 rows from Playlist table:\n",
      "PlaylistId\tName\n",
      "1\tMusic\n",
      "2\tMovies\n",
      "3\tTV Shows\n",
      "*/\u001b[0m\u001b[32;1m\u001b[1;3mThe \"Playlists\" table has the following schema:\n",
      "\n",
      "- PlaylistId: INTEGER (Primary Key)\n",
      "- Name: NVARCHAR(120)\n",
      "\n",
      "Here are three sample rows from the \"Playlists\" table:\n",
      "\n",
      "| PlaylistId | Name      |\n",
      "|------------|-----------|\n",
      "| 1          | Music     |\n",
      "| 2          | Movies    |\n",
      "| 3          | TV Shows  |\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input': 'Describe the playlists table',\n",
       " 'output': 'The \"Playlists\" table has the following schema:\\n\\n- PlaylistId: INTEGER (Primary Key)\\n- Name: NVARCHAR(120)\\n\\nHere are three sample rows from the \"Playlists\" table:\\n\\n| PlaylistId | Name      |\\n|------------|-----------|\\n| 1          | Music     |\\n| 2          | Movies    |\\n| 3          | TV Shows  |'}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_executor.invoke({\"input\": \"Describe the playlists table\"})"
   ]
  },
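  {
   "cell_type": "markdown",
   "id": "b7f1c3a9-5d2e-4c8b-a0e4-9f6d3b2a1c73",
   "metadata": {},
   "source": [
    "Here the LLM maps `playlists` to `Playlist` by reading the output of `sql_db_list_tables`. You may want a deterministic version of that lookup in your own code, for example to validate table names before running a query. Fuzzy string matching gives a similar result. The sketch below swaps in Python's `difflib.get_close_matches` and uses the table list from the trace above. `closest_table` is a hypothetical helper, not something the agent or toolkit does internally. Its `0.6` cutoff is difflib's default rather than a tuned value.\n",
    "\n",
    "```python\n",
    "import difflib\n",
    "\n",
    "# Table names as reported by sql_db_list_tables in the trace above.\n",
    "TABLES = ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice',\n",
    "          'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n",
    "\n",
    "\n",
    "def closest_table(requested, tables=TABLES, cutoff=0.6):\n",
    "    # Compare case-insensitively, then map the match back to the real table name.\n",
    "    by_lower = {name.lower(): name for name in tables}\n",
    "    matches = difflib.get_close_matches(\n",
    "        requested.lower(), list(by_lower), n=1, cutoff=cutoff\n",
    "    )\n",
    "    return by_lower[matches[0]] if matches else None\n",
    "\n",
    "\n",
    "print(closest_table('playlists'))  # Playlist\n",
    "print(closest_table('invoice_lines'))\n",
    "print(closest_table('weather'))\n",
    "```\n",
    "\n",
    "Returning `None` when nothing clears the cutoff lets the caller report an unknown table rather than silently picking an unrelated one."
   ]
  },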
  {
   "cell_type": "markdown",
   "id": "6fbc26af-97e4-4a21-82aa-48bdc992da26",
   "metadata": {},
   "source": [
    "## Example: running queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5457af9a-037d-4dba-9316-0abfe841f114",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3m\n",
      "Invoking: `sql_db_list_tables` with ``\n",
      "\n",
      "\n",
      "\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
      "Invoking: `sql_db_schema` with `Invoice, Customer`\n",
      "\n",
      "\n",
      "\u001b[0m\u001b[33;1m\u001b[1;3m\n",
      "CREATE TABLE \"Customer\" (\n",
      "\t\"CustomerId\" INTEGER NOT NULL, \n",
      "\t\"FirstName\" NVARCHAR(40) NOT NULL, \n",
      "\t\"LastName\" NVARCHAR(20) NOT NULL, \n",
      "\t\"Company\" NVARCHAR(80), \n",
      "\t\"Address\" NVARCHAR(70), \n",
      "\t\"City\" NVARCHAR(40), \n",
      "\t\"State\" NVARCHAR(40), \n",
      "\t\"Country\" NVARCHAR(40), \n",
      "\t\"PostalCode\" NVARCHAR(10), \n",
      "\t\"Phone\" NVARCHAR(24), \n",
      "\t\"Fax\" NVARCHAR(24), \n",
      "\t\"Email\" NVARCHAR(60) NOT NULL, \n",
      "\t\"SupportRepId\" INTEGER, \n",
      "\tPRIMARY KEY (\"CustomerId\"), \n",
      "\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n",
      ")\n",
      "\n",
      "/*\n",
      "3 rows from Customer table:\n",
      "CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n",
      "1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n",
      "2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n",
      "3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n",
      "*/\n",
      "\n",
      "\n",
      "CREATE TABLE \"Invoice\" (\n",
      "\t\"InvoiceId\" INTEGER NOT NULL, \n",
      "\t\"CustomerId\" INTEGER NOT NULL, \n",
      "\t\"InvoiceDate\" DATETIME NOT NULL, \n",
      "\t\"BillingAddress\" NVARCHAR(70), \n",
      "\t\"BillingCity\" NVARCHAR(40), \n",
      "\t\"BillingState\" NVARCHAR(40), \n",
      "\t\"BillingCountry\" NVARCHAR(40), \n",
      "\t\"BillingPostalCode\" NVARCHAR(10), \n",
      "\t\"Total\" NUMERIC(10, 2) NOT NULL, \n",
      "\tPRIMARY KEY (\"InvoiceId\"), \n",
      "\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n",
      ")\n",
      "\n",
      "/*\n",
      "3 rows from Invoice table:\n",
      "InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n",
      "1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n",
      "2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n",
      "3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n",
      "*/\u001b[0m\u001b[32;1m\u001b[1;3m\n",
      "Invoking: `sql_db_query` with `SELECT c.Country, SUM(i.Total) AS TotalSales FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC LIMIT 1;`\n",
      "responded: Based on the schema of the `Customer` and `Invoice` tables, we can join these two tables on the `CustomerId` column to get the total sales per customer. Then, we can group the results by the `Country` column to get the total sales per country. Finally, we can sort the results in descending order of total sales and select the country with the highest total sales.\n",
      "\n",
      "Here is the SQL query to achieve this:\n",
      "\n",
      "```sql\n",
      "SELECT c.Country, SUM(i.Total) AS TotalSales\n",
      "FROM Customer c\n",
      "JOIN Invoice i ON c.CustomerId = i.CustomerId\n",
      "GROUP BY c.Country\n",
      "ORDER BY TotalSales DESC\n",
      "LIMIT 1;\n",
      "```\n",
      "\n",
      "Let me execute this query to find out which country's customers spent the most.\n",
      "\n",
      "\u001b[0m\u001b[36;1m\u001b[1;3m[('USA', 523.0600000000003)]\u001b[0m\u001b[32;1m\u001b[1;3mThe country whose customers spent the most is the USA, with a total sales amount of $523.06.\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input': \"List the total sales per country. Which country's customers spent the most?\",\n",
       " 'output': 'The country whose customers spent the most is the USA, with a total sales amount of $523.06.'}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_executor.invoke(\n",
    "    {\n",
    "        \"input\": \"List the total sales per country. Which country's customers spent the most?\"\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c7503b5-d9d9-4faa-b064-29fcdb5ff213",
   "metadata": {},
   "source": [
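    "## Example: recovering from an error\n",
    "\n",
    "To answer this question, the agent must join `Artist` to `InvoiceLine` through `Album` and `Track`, because `Track` has no `ArtistId` column. If the agent first references a column that doesn't exist, such as `Track.ArtistId`, it can read the resulting error, inspect the relevant schemas, and retry. In the run captured below, the agent chose a valid join path on its first query."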
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "4d8d8933-c1e5-4ba9-8143-1059606c4adb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3m\n",
      "Invoking: `sql_db_list_tables` with ``\n",
      "\n",
      "\n",
      "\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
      "Invoking: `sql_db_schema` with `InvoiceLine`\n",
      "\n",
      "\n",
      "\u001b[0m\u001b[33;1m\u001b[1;3m\n",
      "CREATE TABLE \"InvoiceLine\" (\n",
      "\t\"InvoiceLineId\" INTEGER NOT NULL, \n",
      "\t\"InvoiceId\" INTEGER NOT NULL, \n",
      "\t\"TrackId\" INTEGER NOT NULL, \n",
      "\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \n",
      "\t\"Quantity\" INTEGER NOT NULL, \n",
      "\tPRIMARY KEY (\"InvoiceLineId\"), \n",
      "\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
      "\tFOREIGN KEY(\"InvoiceId\") REFERENCES \"Invoice\" (\"InvoiceId\")\n",
      ")\n",
      "\n",
      "/*\n",
      "3 rows from InvoiceLine table:\n",
      "InvoiceLineId\tInvoiceId\tTrackId\tUnitPrice\tQuantity\n",
      "1\t1\t2\t0.99\t1\n",
      "2\t1\t4\t0.99\t1\n",
      "3\t2\t6\t0.99\t1\n",
      "*/\u001b[0m\u001b[32;1m\u001b[1;3m\n",
      "Invoking: `sql_db_schema` with `Artist`\n",
      "\n",
      "\n",
      "\u001b[0m\u001b[33;1m\u001b[1;3m\n",
      "CREATE TABLE \"Artist\" (\n",
      "\t\"ArtistId\" INTEGER NOT NULL, \n",
      "\t\"Name\" NVARCHAR(120), \n",
      "\tPRIMARY KEY (\"ArtistId\")\n",
      ")\n",
      "\n",
      "/*\n",
      "3 rows from Artist table:\n",
      "ArtistId\tName\n",
      "1\tAC/DC\n",
      "2\tAccept\n",
      "3\tAerosmith\n",
      "*/\u001b[0m\u001b[32;1m\u001b[1;3m\n",
      "Invoking: `sql_db_query` with `SELECT Artist.Name, SUM(InvoiceLine.Quantity) AS TotalQuantity FROM Artist JOIN Album ON Artist.ArtistId = Album.ArtistId JOIN Track ON Album.AlbumId = Track.AlbumId JOIN InvoiceLine ON Track.TrackId = InvoiceLine.TrackId GROUP BY Artist.Name ORDER BY TotalQuantity DESC LIMIT 3`\n",
      "\n",
      "\n",
      "\u001b[0m\u001b[36;1m\u001b[1;3m[('Iron Maiden', 140), ('U2', 107), ('Metallica', 91)]\u001b[0m\u001b[32;1m\u001b[1;3mThe top 3 best selling artists are:\n",
      "\n",
      "1. Iron Maiden - 140 total quantity sold\n",
      "2. U2 - 107 total quantity sold\n",
      "3. Metallica - 91 total quantity sold\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input': 'Who are the top 3 best selling artists?',\n",
       " 'output': 'The top 3 best selling artists are:\\n\\n1. Iron Maiden - 140 total quantity sold\\n2. U2 - 107 total quantity sold\\n3. Metallica - 91 total quantity sold'}"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_executor.invoke({\"input\": \"Who are the top 3 best selling artists?\"})"
   ]
  },
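  {
   "cell_type": "markdown",
   "id": "d4a8e2c6-7b1f-4e3d-8c9a-5f0b6e2d1a94",
   "metadata": {},
   "source": [
    "Recovery like this works because LangChain's SQL query tool returns a database error to the LLM as text instead of raising it. The model can then read the message, inspect the schema, and try again. The sketch below imitates that loop by hand with Python's `sqlite3` module. `run_query` is a hypothetical helper that returns either rows or an `Error: ...` string. The tiny `Album` and `Track` tables are made up for illustration and are not Chinook data.\n",
    "\n",
    "```python\n",
    "import sqlite3\n",
    "\n",
    "# Made-up stand-ins for Chinook's Album and Track tables.\n",
    "conn = sqlite3.connect(':memory:')\n",
    "conn.executescript('''\n",
    "CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER);\n",
    "CREATE TABLE Track (TrackId INTEGER PRIMARY KEY, Name TEXT, AlbumId INTEGER);\n",
    "INSERT INTO Album VALUES (1, 'Album A', 1);\n",
    "INSERT INTO Track VALUES (1, 'Track 1', 1), (2, 'Track 2', 1);\n",
    "''')\n",
    "\n",
    "\n",
    "def run_query(sql):\n",
    "    # Hypothetical helper: return rows on success, or the error text on failure,\n",
    "    # so that a caller (such as an LLM agent) can read the message and retry.\n",
    "    try:\n",
    "        return conn.execute(sql).fetchall()\n",
    "    except sqlite3.Error as exc:\n",
    "        return f'Error: {exc}'\n",
    "\n",
    "\n",
    "# First attempt: Track has no ArtistId column, so SQLite reports an error.\n",
    "first = run_query('SELECT ArtistId, COUNT(*) FROM Track GROUP BY ArtistId')\n",
    "print(first)\n",
    "\n",
    "# Second attempt: join through Album, which does have an ArtistId column.\n",
    "second = run_query(\n",
    "    'SELECT Album.ArtistId, COUNT(*) FROM Track '\n",
    "    'JOIN Album ON Track.AlbumId = Album.AlbumId '\n",
    "    'GROUP BY Album.ArtistId'\n",
    ")\n",
    "print(second)  # [(1, 2)]\n",
    "```"
   ]
  },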
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ca6724d-89ee-49eb-8cdb-5cf66d242a83",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
